{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3cd451fd",
   "metadata": {},
   "source": [
    "# The Optimal BERT Surgeon: Scalable and Accurate Second-Order Pruning for Large Language Models (oBERT)\n",
    "\n",
    "### Paper: [https://arxiv.org/abs/2203.07259](https://arxiv.org/abs/2203.07259)\n",
    "\n",
    "The oBERT implementation is integrated into the SparseML library as the [OBSPruningModifier](https://github.com/neuralmagic/sparseml/blob/main/src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py), which makes it straightforward to run experiments, reproduce results from the paper, or compress new models.\n",
    "We also provide the [bash scripts](https://github.com/neuralmagic/sparseml/tree/main/research/optimal_BERT_surgeon_oBERT/scripts) and [recipes](https://github.com/neuralmagic/sparseml/tree/main/research/optimal_BERT_surgeon_oBERT/recipes) used to produce the results in the paper; they can be adapted to new models and datasets.\n",
    "\n",
    "Here, we extract the algorithmic part of oBERT unstructured pruning from the `OBSPruningModifier` to showcase the main operations involved in the pruning process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "49401cdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import Tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26ab0023",
   "metadata": {},
   "source": [
    "The following `EmpiricalBlockFisherInverse` class computes and stores a block-wise approximation of the inverse Hessian. The Hessian itself is approximated by a dampened empirical Fisher information matrix, built from $m$ gradients with dampening $\\lambda > 0$:\n",
    "$$\n",
    "H_{\\mathcal{L}}(\\mathbf{w}) \\simeq \\widehat{\\mathbf{F}} (\\mathbf{w}) = \\lambda \\mathbf{I}_d + \\frac{1}{m} \\sum_{i=1}^{m} \\nabla \\mathcal{L}_i(\\mathbf{w}) \\nabla \\mathcal{L}^\\top_i(\\mathbf{w})\n",
    "$$\n",
    "Relying on the fact that this is a sum of rank-1 matrices, the Woodbury/Sherman-Morrison inversion formula can be utilized to exactly calculate the Fisher inverse. Unrolling the recursive formulation with $ \\widehat{\\mathbf{F}}^{-1}_0(\\mathbf{w}) = \\frac{1}{\\lambda} \\mathbf{I}_d$, we can obtain an iterative formula to exactly calculate the inverse of the empirical Fisher matrix as:\n",
    "$$\n",
    "\\widehat{\\mathbf{F}}^{-1}(\\mathbf{w}) = \\widehat{\\mathbf{F}}^{-1}_m(\\mathbf{w}) = \\frac{1}{\\lambda} \\mathbf{I}_d - \\sum_{i=1}^{m} \\frac{\\left(\\widehat{\\mathbf{F}}^{-1}_{i-1}(\\mathbf{w}) \\nabla \\mathcal{L}_i(\\mathbf{w})\\right)\\left(\\widehat{\\mathbf{F}}^{-1}_{i-1}(\\mathbf{w}) \\nabla \\mathcal{L}_i(\\mathbf{w})\\right)^\\top}{m + \\nabla \\mathcal{L}_i^\\top(\\mathbf{w}) \\widehat{\\mathbf{F}}^{-1}_{i-1}(\\mathbf{w}) \\nabla \\mathcal{L}_i(\\mathbf{w})}\n",
    "$$\n",
    "\n",
    "This is implemented via the `add_grad` method, which efficiently updates the inverse with a new gradient.\n",
    "\n",
    "`diag` fetches the diagonal of the inverse Fisher, which is used in calculations of the saliency score $\\rho$ and of the optimal weight update $\\delta \\mathbf{w}$.\n",
    "`mul` efficiently computes matrix-vector products between a given vector `v` and the block-wise inverse Fisher matrix, which is used to calculate the optimal weight update $\\delta \\mathbf{w}$."
   ]
  },
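  {
   "cell_type": "markdown",
   "id": "5e1a9c40",
   "metadata": {},
   "source": [
    "As a sketch of where the iterative formula comes from, consider a single Sherman-Morrison step. The shorthand $\\mathbf{g}_i = \\nabla \\mathcal{L}_i(\\mathbf{w})$ is introduced here only for brevity. Adding the $i$-th rank-1 term to $\\widehat{\\mathbf{F}}_{i-1}$ gives:\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\widehat{\\mathbf{F}}^{-1}_i\n",
    "&= \\left(\\widehat{\\mathbf{F}}_{i-1} + \\frac{1}{m} \\mathbf{g}_i \\mathbf{g}_i^\\top\\right)^{-1} \\\\\n",
    "&= \\widehat{\\mathbf{F}}^{-1}_{i-1} - \\frac{\\frac{1}{m} \\widehat{\\mathbf{F}}^{-1}_{i-1} \\mathbf{g}_i \\mathbf{g}_i^\\top \\widehat{\\mathbf{F}}^{-1}_{i-1}}{1 + \\frac{1}{m} \\mathbf{g}_i^\\top \\widehat{\\mathbf{F}}^{-1}_{i-1} \\mathbf{g}_i} \\\\\n",
    "&= \\widehat{\\mathbf{F}}^{-1}_{i-1} - \\frac{\\left(\\widehat{\\mathbf{F}}^{-1}_{i-1} \\mathbf{g}_i\\right)\\left(\\widehat{\\mathbf{F}}^{-1}_{i-1} \\mathbf{g}_i\\right)^\\top}{m + \\mathbf{g}_i^\\top \\widehat{\\mathbf{F}}^{-1}_{i-1} \\mathbf{g}_i}\n",
    "\\end{aligned}\n",
    "$$\n",
    "The last line multiplies the numerator and the denominator by $m$ and uses the symmetry of $\\widehat{\\mathbf{F}}^{-1}_{i-1}$. Writing $\\mathbf{u}_i = \\widehat{\\mathbf{F}}^{-1}_{i-1} \\mathbf{g}_i / \\sqrt{m + \\mathbf{g}_i^\\top \\widehat{\\mathbf{F}}^{-1}_{i-1} \\mathbf{g}_i}$, the step becomes $\\widehat{\\mathbf{F}}^{-1}_i = \\widehat{\\mathbf{F}}^{-1}_{i-1} - \\mathbf{u}_i \\mathbf{u}_i^\\top$, which is the form `add_grad` applies independently to every $B \\times B$ block."
   ]
  },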
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fb0a4e38",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EmpiricalBlockFisherInverse:\n",
    "    \"\"\"Block-wise inverse of the dampened empirical Fisher, updated one gradient at a time.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        num_grads: int,\n",
    "        fisher_block_size: int,\n",
    "        num_weights: int,\n",
    "        damp: float,\n",
    "        device: torch.device,\n",
    "    ):\n",
    "        self.m = num_grads\n",
    "        self.B = fisher_block_size\n",
    "        self.d = num_weights\n",
    "        self.damp = damp\n",
    "        self.dev = device\n",
    "\n",
    "        self.num_blocks = math.ceil(self.d / self.B)\n",
    "        self.F_inv = (\n",
    "            (1.0 / self.damp * torch.eye(n=self.B, device=self.dev))\n",
    "            .unsqueeze(0)\n",
    "            .repeat(self.num_blocks, 1, 1)\n",
    "        )  # takes O(d x B) memory on a device\n",
    "\n",
    "    def add_grad(self, g: Tensor):\n",
    "        \"\"\"\n",
    "        Updates empirical Fisher inverse with a new gradient\n",
    "        :param g: a collected gradient\n",
    "        \"\"\"\n",
    "        # if 'd / B' is not integer, pad with zeros for batch calculations\n",
    "        if g.numel() < self.num_blocks * self.B:\n",
    "            g = torch.cat(\n",
    "                [g, torch.zeros(self.num_blocks * self.B - g.numel(), device=g.device)]\n",
    "            )\n",
    "\n",
    "        # prepare grad for batch calculations\n",
    "        g = g.view(self.num_blocks, self.B)\n",
    "\n",
    "        # batched F_inv x g: (batch, B, B) x (batch, B) -> (batch, B)\n",
    "        Finv_g = torch.einsum(\"bij,bj->bi\", self.F_inv, g)\n",
    "\n",
    "        # square root of the scalar denominator (m + g^T F_inv g) for each block: (batch, 1)\n",
    "        # dividing Finv_g by it lets the outer product below carry the full 1/denominator factor\n",
    "        alpha = (self.m + torch.einsum(\"bi,bi->b\", g, Finv_g)).sqrt().unsqueeze(1)\n",
    "        Finv_g /= alpha\n",
    "\n",
    "        # update F_inv with new outer product: (batch, B) x (batch, B) -> (batch, B, B)\n",
    "        self.F_inv.baddbmm_(Finv_g.unsqueeze(2), Finv_g.unsqueeze(1), alpha=-1)\n",
    "\n",
    "    def diag(self) -> Tensor:\n",
    "        \"\"\"\n",
    "        :return: diagonal of the Fisher inverse matrix\n",
    "        \"\"\"\n",
    "        return self.F_inv.diagonal(dim1=1, dim2=2).flatten()[: self.d]\n",
    "\n",
    "    def mul(self, v: Tensor) -> Tensor:\n",
    "        \"\"\"\n",
    "        Computes matrix-vector product of the Fisher inverse matrix and a vector\n",
    "        :param v: a vector to compute matrix-vector product with\n",
    "        :return: result of the matrix-vector multiplication\n",
    "        \"\"\"\n",
    "        if v.numel() < self.num_blocks * self.B:\n",
    "            v = torch.cat(\n",
    "                [v, torch.zeros(self.num_blocks * self.B - v.numel(), device=v.device)]\n",
    "            )\n",
    "        return torch.bmm(\n",
    "            self.F_inv, v.view(self.num_blocks, self.B).unsqueeze_(2)\n",
    "        ).flatten()[: self.d]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27f7f1dc",
   "metadata": {},
   "source": [
    "Now, we define a dummy model, represented here by a flat vector of `d` prunable weights:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "de7dd26d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use a GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "d = 1000                          # number of prunable weights\n",
    "w = torch.rand(d, device=device)  # dummy weights\n",
    "target_sparsity = 0.7             # fraction of weights to prune, in [0, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68c6975d",
   "metadata": {},
   "source": [
    "Now, we specify oBERT pruning hyper-parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e67be21f",
   "metadata": {},
   "outputs": [],
   "source": [
    "m = 100           # number of gradients\n",
    "B = 50            # block size\n",
    "lambd = 1e-7      # dampening\n",
    "\n",
    "# initialize Fisher inverse, occupies O(Bd) memory\n",
    "# for example, with float32: d=85_000_000, B=50 -> 85_000_000 * 50 * 4 / 1024**3 ~= 15.8 GiB\n",
    "fisher_inv = EmpiricalBlockFisherInverse(m, B, d, lambd, device)"
   ]
  },
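  {
   "cell_type": "markdown",
   "id": "7c2d8b16",
   "metadata": {},
   "source": [
    "As a rough worked example of the memory estimate above, assuming 32-bit floats (4 bytes per entry), the block-wise inverse stores $\\lceil d / B \\rceil$ blocks of $B \\times B$ entries. For $d = 85 \\cdot 10^6$ and $B = 50$:\n",
    "$$\n",
    "\\left\\lceil \\frac{d}{B} \\right\\rceil \\cdot B^2 \\cdot 4 \\text{ bytes} = 1.7 \\cdot 10^6 \\cdot 2500 \\cdot 4 \\text{ bytes} = 1.7 \\cdot 10^{10} \\text{ bytes} \\approx 15.8 \\text{ GiB}\n",
    "$$\n",
    "A dense $d \\times d$ inverse would instead need $d^2 \\cdot 4 \\approx 2.9 \\cdot 10^{16}$ bytes, so the block-wise form is what makes storing the inverse feasible at this scale."
   ]
  },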
  {
   "cell_type": "markdown",
   "id": "9a548e2b",
   "metadata": {},
   "source": [
    "Now, we collect `m` gradients used to approximate the Fisher inverse:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b6c25084",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fisher inverse updated with 100 gradients\r\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for i in range(m):\n",
    "    grad = torch.rand(d, device=device)  # a dummy gradient\n",
    "    fisher_inv.add_grad(grad)\n",
    "    print(f\"Fisher inverse updated with {i+1} gradients\", end=\"\\r\")\n",
    "print('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10ab78cd",
   "metadata": {},
   "source": [
    "Now, we calculate saliency scores for each weight $j \\in \\{1, 2, 3, \\dots, d\\}$ in the form:\n",
    "$$\n",
    "\\rho_j = \\frac{w_j^2}{2 \\widehat{\\mathbf{F}}^{-1}_{j,j}}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bba1eec9",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = (w**2) / (2.0 * fisher_inv.diag())"
   ]
  },
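  {
   "cell_type": "markdown",
   "id": "d04f6a23",
   "metadata": {},
   "source": [
    "The score can be motivated by the classical Optimal Brain Surgeon argument. The following is only a sketch, assuming a second-order Taylor expansion of the loss around trained weights (so the gradient term is neglected) with the Hessian replaced by $\\widehat{\\mathbf{F}}$; the multiplier $\\mu$ is notation introduced here. Removing weight $w_j$ amounts to solving:\n",
    "$$\n",
    "\\min_{\\delta \\mathbf{w}} \\; \\frac{1}{2} \\delta \\mathbf{w}^\\top \\widehat{\\mathbf{F}} \\, \\delta \\mathbf{w} \\quad \\text{s.t.} \\quad \\mathbf{e}_j^\\top \\delta \\mathbf{w} + w_j = 0\n",
    "$$\n",
    "Setting the gradient of the Lagrangian $\\frac{1}{2} \\delta \\mathbf{w}^\\top \\widehat{\\mathbf{F}} \\, \\delta \\mathbf{w} + \\mu \\left(\\mathbf{e}_j^\\top \\delta \\mathbf{w} + w_j\\right)$ to zero gives $\\delta \\mathbf{w} = -\\mu \\widehat{\\mathbf{F}}^{-1} \\mathbf{e}_j$, and the constraint then fixes $\\mu = w_j / \\widehat{\\mathbf{F}}^{-1}_{j,j}$. Substituting back, the loss increase is:\n",
    "$$\n",
    "\\frac{1}{2} \\delta \\mathbf{w}^\\top \\widehat{\\mathbf{F}} \\, \\delta \\mathbf{w} = \\frac{\\mu^2}{2} \\mathbf{e}_j^\\top \\widehat{\\mathbf{F}}^{-1} \\mathbf{e}_j = \\frac{\\mu^2}{2} \\widehat{\\mathbf{F}}^{-1}_{j,j} = \\frac{w_j^2}{2 \\widehat{\\mathbf{F}}^{-1}_{j,j}} = \\rho_j\n",
    "$$\n",
    "so weights with the smallest $\\rho_j$ are the cheapest to remove."
   ]
  },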
  {
   "cell_type": "markdown",
   "id": "7bd1af7b",
   "metadata": {},
   "source": [
    "Now, we prune `target_sparsity * d` weights:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2820ab8c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pruned model's sparsity = 0.699999988079071\n"
     ]
    }
   ],
   "source": [
    "# find pruning threshold\n",
    "kth_score = torch.kthvalue(scores, round(target_sparsity * d))[0]\n",
    "\n",
    "# prune (i.e. set masks)\n",
    "mask = scores > kth_score\n",
    "print(f\"Pruned model's sparsity = {1 - torch.sum(mask)/mask.numel()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94145d73",
   "metadata": {},
   "source": [
    "In addition to removing a weight $w_j$, the OBS framework updates the remaining weights to compensate for the loss increase caused by pruning. The optimal update, which zeroes $w_j$ and adjusts the other weights, is given by:\n",
    "$$\n",
    "\\delta\\mathbf{w}_j = -\\frac{w_j}{\\widehat{\\mathbf{F}}^{-1}_{j,j}}\\widehat{\\mathbf{F}}^{-1} \\mathbf{e}_j\n",
    "$$\n",
    "As described in the paper, finding the jointly optimal update for many weights pruned at once is combinatorially intractable, so the per-weight updates are summed instead. Because the update coming from one pruned weight can perturb another pruned weight to a non-zero value, we have to explicitly zero out the pruned weights afterwards."
   ]
  },
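  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cc19c746",
   "metadata": {},
   "outputs": [],
   "source": [
    "# apply the summed OBS updates: `mul` gets w_j / F_inv[j, j] at pruned positions, 0 elsewhere\n",
    "w -= fisher_inv.mul(w * (mask == 0) / fisher_inv.diag())\n",
    "# updates from other pruned weights can leave pruned entries non-zero, so zero them explicitly\n",
    "w[mask == 0] = 0.0"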
   ]
  },
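  {
   "cell_type": "markdown",
   "id": "91b3e7f8",
   "metadata": {},
   "source": [
    "To connect the formula with the update cell above, here is a sketch of the summed update. The set $Q$ of pruned indices and the vector $\\mathbf{v}$ are notation introduced only for this cell:\n",
    "$$\n",
    "\\delta \\mathbf{w} = \\sum_{j \\in Q} \\delta \\mathbf{w}_j = -\\widehat{\\mathbf{F}}^{-1} \\sum_{j \\in Q} \\frac{w_j}{\\widehat{\\mathbf{F}}^{-1}_{j,j}} \\mathbf{e}_j = -\\widehat{\\mathbf{F}}^{-1} \\mathbf{v}, \\qquad v_j = \\begin{cases} w_j / \\widehat{\\mathbf{F}}^{-1}_{j,j} & j \\in Q \\\\ 0 & \\text{otherwise} \\end{cases}\n",
    "$$\n",
    "In the code, $\\mathbf{v}$ corresponds to `w * (mask == 0) / fisher_inv.diag()`, so a single call to `mul` evaluates $\\widehat{\\mathbf{F}}^{-1} \\mathbf{v}$ and the subtraction applies $\\mathbf{w} \\leftarrow \\mathbf{w} + \\delta \\mathbf{w}$."
   ]
  },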
  {
   "cell_type": "markdown",
   "id": "b605fa8f",
   "metadata": {},
   "source": [
    "The 4-block oBERT pruning follows the same procedure, except that it uses slightly different equations for the saliency score and the optimal weight update, which can be found in the paper and in the [SparseML integration](https://github.com/neuralmagic/sparseml/blob/main/src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
